{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# %set_env CUDA_VISIBLE_DEVICES=7\n",
    "# import sys; sys.path.append('/future/u/okhattab/repos/public/stanfordnlp/dspy')\n",
    "\n",
    "import dspy\n",
    "from dspy.evaluate import Evaluate\n",
    "from dspy.datasets.hotpotqa import HotPotQA\n",
    "from dspy.teleprompt import BootstrapFewShotWithRandomSearch, BootstrapFinetune"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1) Configure the default LM and retriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "ports = [7140, 7141, 7142, 7143, 7144, 7145]\n",
    "llamaChat = dspy.HFClientTGI(model=\"meta-llama/Llama-2-13b-chat-hf\", port=ports, max_tokens=150)\n",
    "colbertv2 = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')\n",
    "\n",
    "dspy.settings.configure(rm=colbertv2, lm=llamaChat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2) Load a small sample of HotPotQA data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(200, 1000, 0)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = HotPotQA(train_seed=1, train_size=200, eval_seed=2023, dev_size=1000, test_size=0)\n",
    "trainset = [x.with_inputs('question') for x in dataset.train]\n",
    "devset = [x.with_inputs('question') for x in dataset.dev]\n",
    "testset = [x.with_inputs('question') for x in dataset.test]\n",
    "\n",
    "len(trainset), len(devset), len(testset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Example({'question': 'At My Window was released by which American singer-songwriter?', 'answer': 'John Townes Van Zandt'}) (input_keys={'question'})"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3) Define a simple multi-hop program"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dsp.utils.utils import deduplicate\n",
    "\n",
    "class BasicMH(dspy.Module):\n",
    "    def __init__(self, passages_per_hop=3):\n",
    "        super().__init__()\n",
    "\n",
    "        self.retrieve = dspy.Retrieve(k=passages_per_hop)\n",
    "        self.generate_query = [dspy.ChainOfThought(\"context, question -> search_query\") for _ in range(2)]\n",
    "        self.generate_answer = dspy.ChainOfThought(\"context, question -> answer\")\n",
    "    \n",
    "    def forward(self, question):\n",
    "        context = []\n",
    "        \n",
    "        for hop in range(2):\n",
    "            search_query = self.generate_query[hop](context=context, question=question).search_query\n",
    "            passages = self.retrieve(search_query).passages\n",
    "            context = deduplicate(context + passages)\n",
    "\n",
    "        return self.generate_answer(context=context, question=question).copy(context=context)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4) Compile the program with `Llama2-13b-chat`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "RECOMPILE_INTO_LLAMA_FROM_SCRATCH = False\n",
    "NUM_THREADS = 24\n",
    "\n",
    "metric_EM = dspy.evaluate.answer_exact_match"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "if RECOMPILE_INTO_LLAMA_FROM_SCRATCH:\n",
    "    tp = BootstrapFewShotWithRandomSearch(metric=metric_EM, max_bootstrapped_demos=2, num_threads=NUM_THREADS)\n",
    "    basicmh_bs = tp.compile(BasicMH(), trainset=trainset[:50], valset=trainset[50:200])\n",
    "\n",
    "    ensemble = [prog for *_, prog in basicmh_bs.candidate_programs[:4]]\n",
    "\n",
    "    for idx, prog in enumerate(ensemble):\n",
    "        # prog.save(f'checkpoints/multihop_llama213b_{idx}.json')\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RECOMPILE_INTO_LLAMA_FROM_SCRATCH:\n",
    "    ensemble = []\n",
    "\n",
    "    for idx in range(4):\n",
    "        prog = BasicMH()\n",
    "        prog.load(f'checkpoints/multihop_llama213b_{idx}.json')\n",
    "        ensemble.append(prog)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Metric: 424 / 1000  (42.4): 100%|██████████| 1000/1000 [00:14<00:00, 70.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Metric: 424 / 1000  (42.4%)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "42.4"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llama_program = ensemble[0]\n",
    "\n",
    "evaluate_hotpot = Evaluate(devset=devset[:1000], metric=metric_EM, num_threads=NUM_THREADS, display_progress=True, display_table=0)\n",
    "evaluate_hotpot(llama_program)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llama_program(question=\"How many storeys are in the castle that David Gregory inherited?\")\n",
    "\n",
    "llamaChat.inspect_history(n=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6) Compile into `T5-Large` (770M parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3000"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unlabeled_train = HotPotQA(train_seed=1, train_size=3000, eval_seed=2023, dev_size=0, test_size=0).train\n",
    "unlabeled_train = [dspy.Example(question=x.question).with_inputs('question') for x in unlabeled_train]\n",
    "len(unlabeled_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional step: pre-compute the ensemble on the unlabeled training set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "always_true = lambda g, p, trace=None: True\n",
    "\n",
    "for prog_ in ensemble:\n",
    "    evaluate_hotpot(prog_, devset=unlabeled_train[:3000], metric=always_true)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now compile into T5!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "RECOMPILE_INTO_T5_FROM_SCRATCH = False\n",
    "\n",
    "if RECOMPILE_INTO_T5_FROM_SCRATCH:\n",
    "    config = dict(target='t5-large', epochs=2, bf16=True, bsize=6, accumsteps=2, lr=5e-5)\n",
    "\n",
    "    tp = BootstrapFinetune(metric=None)\n",
    "    t5_program = tp.compile(BasicMH(), teacher=ensemble, trainset=unlabeled_train[:3000], **config)\n",
    "\n",
    "    # Deactivate chain of thought prompting. Let's use T5 to directly predict outputs. (Faster and similar quality.)\n",
    "    for p in t5_program.predictors(): p.activated = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not RECOMPILE_INTO_T5_FROM_SCRATCH:\n",
    "    t5_program = BasicMH()\n",
    "\n",
    "    # ckpt_path = '../finetuning_ckpts/LMWEP0WZ5IKWM.all/checkpoint-5400'\n",
    "    ckpt_path = \"colbert-ir/dspy-Oct11-T5-Large-MH-3k-v1\"\n",
    "    LM = dspy.HFModel(checkpoint=ckpt_path, model='t5-large')\n",
    "\n",
    "    for p in t5_program.predictors():\n",
    "        p.lm = LM\n",
    "        p.activated = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7) Evaluate the T5-Large `multihop` program"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score = evaluate_hotpot(t5_program, num_threads=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t5_program.predictors()[0].lm.inspect_history(n=3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py39_aug2023_dspy",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
